1   /*
2    * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
3    *
4    * This code is free software; you can redistribute it and/or modify it
5    * under the terms of the GNU General Public License version 2 only, as
6    * published by the Free Software Foundation.
7    *
8    * This code is distributed in the hope that it will be useful, but WITHOUT
9    * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
10   * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
11   * version 2 for more details (a copy is included in the LICENSE file that
12   * accompanied this code).
13   *
14   * You should have received a copy of the GNU General Public License version
15   * 2 along with this work; if not, write to the Free Software Foundation,
16   * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
17   *
18   * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
19   * or visit www.oracle.com if you need additional information or have any
20   * questions.
21   */
22  
23  /*
24   * This file is available under and governed by the GNU General Public
25   * License version 2 only, as published by the Free Software Foundation.
26   * However, the following notice accompanied the original version of this
27   * file:
28   *
29   * Written by Doug Lea with assistance from members of JCP JSR-166
30   * Expert Group and released to the public domain, as explained at
31   * http://creativecommons.org/publicdomain/zero/1.0/
32   */
33  
34  /*
35   * @test
36   * @bug 6785442
37   * @summary Checks race between poll and remove(Object), while
38   * occasionally moonlighting as a microbenchmark.
39   * @run main RemovePollRace 12345
40   */
41  
42  import java.util.concurrent.ArrayBlockingQueue;
43  import java.util.concurrent.ConcurrentHashMap;
44  import java.util.concurrent.ConcurrentLinkedDeque;
45  import java.util.concurrent.ConcurrentLinkedQueue;
46  import java.util.concurrent.CountDownLatch;
47  import java.util.concurrent.LinkedBlockingDeque;
48  import java.util.concurrent.LinkedBlockingQueue;
49  import java.util.concurrent.LinkedTransferQueue;
50  import java.util.concurrent.atomic.AtomicLong;
51  import java.util.ArrayList;
52  import java.util.Collection;
53  import java.util.Collections;
54  import java.util.List;
55  import java.util.Queue;
56  import java.util.Map;
57  
/**
 * Stress test for the race between {@code Queue.poll()} and
 * {@code Queue.remove(Object)} on the concurrent queue implementations,
 * which also reports a rough nanos-per-item figure for each of them.
 *
 * <p>For every queue, adder threads insert {@code count} copies of
 * {@code Boolean.TRUE} each, while remover and poller threads concurrently
 * take elements out.  The test fails if the number of elements taken out
 * differs from the number of elements put in.
 */
public class RemovePollRace {
    // Suitable for benchmarking.  Overridden by args[0] for testing.
    int count = 1024 * 1024;

    /** Nanos-per-item of the latest run, keyed by the queue's simple class name. */
    final Map<String,String> results = new ConcurrentHashMap<>();

    /** Guards the read-modify-write updates of {@link #passed} and {@link #failed}. */
    private final Object counterLock = new Object();

    /**
     * A fair {@link ArrayBlockingQueue}.  It adds no behavior; it exists only
     * so that the fair variant has its own simple class name and therefore its
     * own entry in {@link #results} instead of colliding with the non-fair one.
     */
    private static final class FairArrayBlockingQueue<E>
        extends ArrayBlockingQueue<E> {
        private static final long serialVersionUID = 1L;

        FairArrayBlockingQueue(int capacity) {
            super(capacity, true);
        }
    }

    /**
     * Creates one fresh instance of every queue implementation under test.
     *
     * @return the queues, in random order
     */
    Collection<Queue<Boolean>> concurrentQueues() {
        List<Queue<Boolean>> queues = new ArrayList<>();
        queues.add(new ConcurrentLinkedDeque<Boolean>());
        queues.add(new ConcurrentLinkedQueue<Boolean>());
        queues.add(new ArrayBlockingQueue<Boolean>(count, false));
        queues.add(new FairArrayBlockingQueue<Boolean>(count));
        queues.add(new LinkedBlockingQueue<Boolean>());
        queues.add(new LinkedBlockingDeque<Boolean>());
        queues.add(new LinkedTransferQueue<Boolean>());

        // Following additional implementations are available from:
        // http://gee.cs.oswego.edu/dl/concurrency-interest/index.html
        // queues.add(new SynchronizedLinkedListQueue<Boolean>());

        // Avoid "first fast, second slow" benchmark effect.
        Collections.shuffle(queues);
        return queues;
    }

    /**
     * Prints one right-aligned line per entry of {@link #results}, sorted by
     * class name.
     */
    void prettyPrintResults() {
        List<String> classNames = new ArrayList<>(results.keySet());
        Collections.sort(classNames);
        int maxClassNameLength = 0;
        int maxNanosLength = 0;
        for (String name : classNames) {
            maxClassNameLength = Math.max(maxClassNameLength, name.length());
            maxNanosLength =
                Math.max(maxNanosLength, results.get(name).length());
        }
        // Build a format such as "%22s %4s nanos/item%n" so columns line up.
        String format = String.format("%%%ds %%%ds nanos/item%%n",
                                      maxClassNameLength, maxNanosLength);
        for (String name : classNames) {
            System.out.printf(format, name, results.get(name));
        }
    }

    /**
     * Runs every queue once as warmup, discards those timings, runs every
     * queue again and prints the timings of the second pass.
     *
     * @param args if non-empty, {@code args[0]} is the number of items each
     *        adder thread inserts, replacing the default {@link #count}
     * @throws NumberFormatException if {@code args[0]} is not an integer
     * @throws Throwable anything thrown by a per-queue run
     */
    void test(String[] args) throws Throwable {
        if (args.length > 0) {
            count = Integer.parseInt(args[0]);
        }
        // Warmup
        for (Queue<Boolean> queue : concurrentQueues()) {
            test(queue);
        }
        results.clear();
        for (Queue<Boolean> queue : concurrentQueues()) {
            test(queue);
        }

        prettyPrintResults();
    }

    /**
     * Waits for the latch.  An interrupt is recorded as a test failure and the
     * thread's interrupt status is restored so it is not silently lost.
     */
    void await(CountDownLatch latch) {
        try {
            latch.await();
        } catch (InterruptedException e) {
            unexpected(e);
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Races adders, removers and pollers against a single queue, stores the
     * elapsed nanos per item in {@link #results} and records a failure if the
     * number of items taken out does not match the number put in.
     *
     * @param q the (initially empty) queue to exercise
     * @throws Throwable if the calling thread is interrupted while joining
     */
    void test(final Queue<Boolean> q) throws Throwable {
        long t0 = System.nanoTime();
        final int SPINS = 5;
        final AtomicLong removes = new AtomicLong(0);
        final AtomicLong polls = new AtomicLong(0);
        final int adderCount =
            Math.max(1, Runtime.getRuntime().availableProcessors() / 4);
        final int removerCount =
            Math.max(1, Runtime.getRuntime().availableProcessors() / 4);
        final int pollerCount = removerCount;
        // Computed in long arithmetic: adderCount * count may exceed int range.
        final long totalItems = (long) adderCount * count;
        final CountDownLatch startingGate = new CountDownLatch(1);
        final CountDownLatch addersDone = new CountDownLatch(adderCount);

        final Runnable remover = new Runnable() {
            @Override
            public void run() {
                await(startingGate);
                int spins = 0;
                for (;;) {
                    // Sample "adders finished" BEFORE the attempt: a failed
                    // attempt made after all adds completed means the queue
                    // really was empty, so it is safe to quit.
                    boolean quittingTime = (addersDone.getCount() == 0);
                    if (q.remove(Boolean.TRUE)) {
                        removes.getAndIncrement();
                    } else if (quittingTime) {
                        break;
                    } else if (++spins > SPINS) {
                        Thread.yield();
                        spins = 0;
                    }
                }
            }
        };

        final Runnable poller = new Runnable() {
            @Override
            public void run() {
                await(startingGate);
                int spins = 0;
                for (;;) {
                    // Same ordering requirement as in the remover.
                    boolean quittingTime = (addersDone.getCount() == 0);
                    // Only the Boolean.TRUE instance is ever added, so an
                    // identity comparison is enough to tell an element
                    // apart from the null returned for an empty queue.
                    if (q.poll() == Boolean.TRUE) {
                        polls.getAndIncrement();
                    } else if (quittingTime) {
                        break;
                    } else if (++spins > SPINS) {
                        Thread.yield();
                        spins = 0;
                    }
                }
            }
        };

        final Runnable adder = new Runnable() {
            @Override
            public void run() {
                await(startingGate);
                for (int i = 0; i < count; i++) {
                    // A bounded queue rejects add() while full; retry until
                    // a consumer has made room.
                    for (;;) {
                        try {
                            q.add(Boolean.TRUE);
                            break;
                        } catch (IllegalStateException e) {
                            Thread.yield();
                        }
                    }
                }
                addersDone.countDown();
            }
        };

        final List<Thread> adders   = new ArrayList<>();
        final List<Thread> removers = new ArrayList<>();
        final List<Thread> pollers  = new ArrayList<>();
        for (int i = 0; i < adderCount; i++) {
            adders.add(checkedThread(adder));
        }
        for (int i = 0; i < removerCount; i++) {
            removers.add(checkedThread(remover));
        }
        for (int i = 0; i < pollerCount; i++) {
            pollers.add(checkedThread(poller));
        }

        final List<Thread> allThreads = new ArrayList<>();
        allThreads.addAll(removers);
        allThreads.addAll(pollers);
        allThreads.addAll(adders);

        for (Thread t : allThreads) {
            t.start();
        }
        startingGate.countDown();
        for (Thread t : allThreads) {
            t.join();
        }

        String className = q.getClass().getSimpleName();
        long elapsed = System.nanoTime() - t0;
        int nanos = (int) ((double) elapsed / totalItems);
        results.put(className, String.valueOf(nanos));
        if (removes.get() + polls.get() != totalItems) {
            String msg = String.format
                ("class=%s removes=%s polls=%d count=%d",
                 className, removes.get(), polls.get(), count);
            fail(msg);
        }
    }

    //--------------------- Infrastructure ---------------------------
    // The counters are bumped from worker threads as well as the main thread.
    // volatile alone does not make ++ atomic, so every update is performed
    // while holding counterLock.
    volatile int passed = 0, failed = 0;

    /** Records one passed check. */
    void pass() {
        synchronized (counterLock) {
            passed++;
        }
    }

    /** Records one failed check and dumps the current stack. */
    void fail() {
        synchronized (counterLock) {
            failed++;
        }
        Thread.dumpStack();
    }

    /** Prints {@code msg} to standard error and records one failed check. */
    void fail(String msg) {
        System.err.println(msg);
        fail();
    }

    /** Records one failed check and prints the stack trace of {@code t}. */
    void unexpected(Throwable t) {
        synchronized (counterLock) {
            failed++;
        }
        t.printStackTrace();
    }

    /** Records a pass if {@code cond} holds, a failure otherwise. */
    void check(boolean cond) {
        if (cond) {
            pass();
        } else {
            fail();
        }
    }

    /** Records a pass if both arguments are null or equal, a failure otherwise. */
    void equal(Object x, Object y) {
        if (x == null ? y == null : x.equals(y)) {
            pass();
        } else {
            fail(x + " not equal to " + y);
        }
    }

    public static void main(String[] args) throws Throwable {
        new RemovePollRace().instanceMain(args);
    }

    /**
     * Runs the test, prints the pass/fail totals and throws if anything
     * failed.
     *
     * @throws AssertionError if at least one failure was recorded
     */
    public void instanceMain(String[] args) throws Throwable {
        try {
            test(args);
        } catch (Throwable t) {
            unexpected(t);
        }
        System.out.printf("%nPassed = %d, failed = %d%n%n", passed, failed);
        if (failed > 0) {
            throw new AssertionError("Some tests failed");
        }
    }

    /**
     * Wraps {@code r} in a thread that reports anything it throws through
     * {@link #unexpected(Throwable)} instead of letting it escape.
     */
    Thread checkedThread(final Runnable r) {
        return new Thread() {
            @Override
            public void run() {
                try {
                    r.run();
                } catch (Throwable t) {
                    unexpected(t);
                }
            }
        };
    }
}